from sklearn.datasets import load_svmlight_file
from Solver import Solver
import Logistic
import numpy
import matplotlib.pyplot as plt


class Demo:
    """Driver that computes a reference solution and measures solver convergence.

    ``fit`` truncates the data to a multiple of ``m`` rows and computes a
    reference solution ``wopt`` via ``Logistic.Solver(...).newton``;
    ``testConvergence`` then runs ``Solver.train`` ``repeat`` times and collects
    the per-iteration error values each run returns.
    """

    def __init__(self, CGmaxiter, repeat, gamma):
        """Store the experiment settings.

        Args:
            CGmaxiter: forwarded to ``Solver.train`` as ``CGmaxIter``.
            repeat: number of independent runs performed by ``testConvergence``.
            gamma: passed to both ``Logistic.Solver.newton`` and
                ``Solver.train`` (presumably the regularization strength --
                confirm against ``Logistic``).
        """
        self.CGmaxiter = CGmaxiter
        self.repeat = repeat
        self.gamma = gamma

    def fit(self, xMat, yVec, m):
        """Truncate the data to a multiple of ``m`` rows and compute ``wopt``.

        Sets ``self.d``, ``self.n``, ``self.xMat``, ``self.yVec`` (reshaped to
        ``(n, 1)``), ``self.wopt`` and ``self.condnum``.

        Args:
            xMat: 2-D feature array; only its first ``n`` rows are kept.
            yVec: labels; the first ``n`` entries are kept as a column vector.
            m: the number of kept rows ``n`` is the largest multiple of ``m``
                not exceeding ``xMat.shape[0]`` (presumably so the data splits
                evenly into ``m`` parts for ``Solver`` -- confirm).

        Raises:
            ValueError: if ``m`` is smaller than 1 or larger than the number
                of rows, either of which would leave no usable data.
        """
        n_rows, d = xMat.shape
        if m < 1 or m > n_rows:
            raise ValueError(
                f'm must be between 1 and the number of rows ({n_rows}); got {m}')
        # Exact integer floor division instead of float division + floor.
        n = (n_rows // m) * m
        self.d = d
        self.xMat = xMat[0:n, :]
        self.yVec = yVec[0:n].reshape(n, 1)
        self.n = n
        solver = Logistic.Solver(X=self.xMat, y=self.yVec)
        self.wopt, self.condnum = solver.newton(self.gamma)

    def testConvergence(self, m, s=None, isSearch=False, newtoniter=20):
        """Run ``Solver.train`` ``self.repeat`` times and collect error traces.

        ``fit`` must have been called first (uses ``self.xMat``, ``self.yVec``
        and ``self.wopt``).

        Args:
            m: forwarded to ``Solver(m=...)``.
            s: forwarded to ``Solver.fit(s=...)``.
            isSearch: forwarded to ``Solver.train``.
            newtoniter: forwarded as ``newtonMaxIter``; the result has
                ``newtoniter + 1`` columns.

        Returns:
            Array of shape ``(self.repeat, newtoniter + 1)``. Row ``r`` holds at
            most the first ``newtoniter + 1`` values returned by run ``r``; if a
            run returns fewer values, the remaining entries stay 0.
        """
        n_cols = newtoniter + 1
        errMat = numpy.zeros((self.repeat, n_cols))
        for r in range(self.repeat):
            print(str(r) + '-th repeat,   m = ' + str(m))
            solver = Solver(m=m, xMat=self.xMat, yVec=self.yVec)
            solver.fit(s=s)
            err = solver.train(self.gamma, self.wopt, CGmaxIter=self.CGmaxiter,
                               isSearch=isSearch, newtonMaxIter=newtoniter)
            # Release this solver before the next one is constructed so two
            # solvers are not alive at once (keeps peak memory lower).
            del solver
            length = min(len(err), n_cols)
            errMat[r, :length] = err[:length]
        return errMat


def experiment(xMat, yVec, CGmaxiter, repeat, gamma, isSearch, s=None,
               newtoniter=20, m=500):
    """Compute a reference solution and measure convergence with ``Demo``.

    Args:
        xMat, yVec: data passed to ``Demo.fit``.
        CGmaxiter, repeat, gamma: forwarded to the ``Demo`` constructor.
        isSearch, s, newtoniter: forwarded to ``Demo.testConvergence``.
        m: used both to truncate the data in ``Demo.fit`` and as the ``m`` of
            every ``Solver`` run. Defaults to the previously hard-coded 500.

    Returns:
        ``(errMat, wopt)``: the error matrix from ``Demo.testConvergence`` and
        the reference solution computed by ``Demo.fit``.
    """
    demo = Demo(CGmaxiter, repeat, gamma)
    # One value of m drives both the truncation and the solver runs, as before
    # (the constant used to be written out twice).
    demo.fit(xMat, yVec, m=m)
    errMat = demo.testConvergence(
        m, s=s, isSearch=isSearch, newtoniter=newtoniter)
    return errMat, demo.wopt


def plotConvergence(outfilename):
    """Plot per-run and median error curves from a saved ``.npz`` and save a JPEG.

    Args:
        outfilename: path to an ``.npz`` file containing ``err`` (2-D, one row
            per repeat, one column per iteration) and ``dataname``, as written
            by this script's ``__main__`` section. ``err`` presumably holds
            ``||w - w*||_2`` per iteration -- confirm against ``Solver.train``.

    Side effects: writes ``figures/demo_logis_<dataname>.jpg`` (creating the
    ``figures`` directory if needed) and opens the figure with ``plt.show()``.
    """
    import os

    # NpzFile keeps the underlying file open; the context manager closes it.
    with numpy.load(outfilename) as npzfile:
        dataname = str(npzfile['dataname'])
        err = npzfile['err']

    fig = plt.figure(figsize=(9, 8))

    # Faint line for each individual repeat.
    for row in err:
        plt.plot(row, color='greenyellow',
                 linestyle='-', linewidth=0.5, alpha=0.5)

    # Bold median across repeats, drawn on top.
    plt.plot(numpy.median(err, axis=0),
             color='g', linestyle='-', linewidth=3)

    plt.xlabel('iterations (t)', fontsize=30)
    plt.xticks([0, 5, 10, 15, 20, 25, 30], fontsize=28)
    plt.yticks(fontsize=28)
    # NOTE(review): a 1e-12 lower y-limit only makes sense on a log-scale
    # y-axis (e.g. plt.yscale('log')); the axis is linear here -- confirm intent.
    plt.axis([0, 30, 1e-12, 5])
    plt.tight_layout()
    os.makedirs('figures', exist_ok=True)
    imagename = 'figures/demo_logis_' + dataname + '.jpg'
    fig.savefig(imagename, format='jpg', dpi=500)
    plt.show()


def loada9a(address='dataset/a9a'):
    """Load an SVMlight/libsvm-format file as dense NumPy arrays.

    Args:
        address: path to the file; defaults to ``'dataset/a9a'``.

    Returns:
        ``(X, Y)``: ``X`` is a dense 2-D ``numpy.ndarray`` of features and ``Y``
        the label array, both as parsed by
        ``sklearn.datasets.load_svmlight_file``.
    """
    X_sparse, Y = load_svmlight_file(address)
    # toarray() returns an ndarray directly. The previous todense() built a
    # numpy.matrix that numpy.array() then copied a second time.
    return X_sparse.toarray(), numpy.asarray(Y)


if __name__ == '__main__':
    import os

    # Settings shared by every (T, q) run.
    Repeat = 20
    Gamma = 1e-6
    s = None  # This is the s in figure 6
    IsSearch = True
    dataname = 'a9a'

    # The dataset does not depend on (T, q): load (and print) it once
    # instead of on every iteration.
    xMat, yVec = loada9a()
    print(xMat)
    print(yVec)

    # numpy.savez fails if the output directory is missing.
    os.makedirs('results', exist_ok=True)

    # (T, q) pairs: T = Newton iterations, q = CG iterations (figure 6).
    values = [(1, 93), (5, 13), (10, 3)]
    for T, q in values:
        CGMaxIter = q  # This is the q in figure 6
        NewtonIter = T  # This is the T in figure 6

        errMat, wopt = experiment(
            xMat, yVec, CGMaxIter, Repeat, Gamma, IsSearch, s, newtoniter=NewtonIter)
        print(errMat)

        outfilename = f'results/demo_logis_{NewtonIter}_' + dataname + '.npz'
        numpy.savez(outfilename, err=errMat, dataname=dataname)
        # Overwritten on every run. wopt comes from the same data and Gamma
        # each time, so it is presumably identical across runs.
        numpy.save("wopt.npy", wopt)
